import argparse
import os
from functools import partial

import torch
import torch.distributed as dist
import  numpy as np
import yaml
from metric import KNN, LinearProbe
from torchvision.utils import make_grid, save_image
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from ema_pytorch import EMA
from model.SODA import SODA
from model.pointnet import PointNetEncoder
from model.decoder import UNet_decoder
from model.encoder import Network
from utils import Config, get_optimizer, init_seeds, reduce_tensor, DataLoaderDDP,DataLoaderNonDDP, print0
#from model.vision_3d.positional_encoding import PositionalEncoding
from model.vision_3d.pointnet2_new_encoder import PointNet2EncoderXYZ
#from model.vision_3d.pointnext_encoder import PointNextEncoderXYZ
from model.vision_3d.pointnet_extractor import PointNetEncoderXYZ
from model.vision_3d.pointnet_extractor import PointTransformer

#from model.vision_3d.pointtransformer_encoder import Backbone as PT_Backbone
from datasets import get_dataset,ShapeNetCore
from dataset.realdex_dataset import RealDexDataset
from dataset.metaworld_dataset import MetaworldDataset
#from pytorch3d.loss import chamfer_distance
import torch
import os

# Default to physical GPU 1 (which then appears to torch as cuda:0), but keep an
# explicit CUDA_VISIBLE_DEVICES from the launching environment instead of
# silently overriding it. Must run before CUDA is initialised.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "1")

def train(opt, zarr_path='/home/fyk/CordViP/policy/3D-Diffusion-Policy/3D-Diffusion-Policy/data/reach_chicken_50.zarr'):
    """Train SODA on a Metaworld zarr dataset and save the encoder weights.

    Args:
        opt: Parsed command-line namespace; ``config`` (path to a YAML config
            file), ``local_rank`` and ``use_amp`` are read from it.
        zarr_path: Path of the Metaworld zarr dataset to train on. Defaults to
            the path that was previously hard-coded here.

    Side effects:
        On rank 0: creates ``<save_dir>/ckpts`` and ``<save_dir>/visual``,
        logs lr/loss scalars to TensorBoard under ``<save_dir>/tensorboard``,
        and after every epoch overwrites ``<save_dir>/ckpts/model_last.pth``
        with ``{'MODEL': soda.encoder.state_dict()}``.
    """
    yaml_path = opt.config
    local_rank = opt.local_rank
    use_amp = opt.use_amp
    is_main = local_rank == 0

    with open(yaml_path, 'r') as f:
        cfg = yaml.full_load(f)
    print0(cfg)
    cfg = Config(cfg)

    model_dir = os.path.join(cfg.save_dir, "ckpts")
    vis_dir = os.path.join(cfg.save_dir, "visual")
    tsbd_dir = os.path.join(cfg.save_dir, "tensorboard")
    if is_main:
        os.makedirs(model_dir, exist_ok=True)
        os.makedirs(vis_dir, exist_ok=True)

    # Fixed device; local_rank is not used to pick the GPU here.
    device = "cuda:0"

    # cfg.encoder is currently unused: the encoder is built with its defaults.
    soda = SODA(encoder=PointNetEncoderXYZ(),
                decoder=UNet_decoder(**cfg.decoder),
                **cfg.diffusion,
                device=device)
    soda.to(device)

    ema = None
    writer = None
    if is_main:
        ema = EMA(soda, beta=cfg.ema, update_after_step=0, update_every=1)
        ema.to(device)
        ema.eval()
        writer = SummaryWriter(log_dir=tsbd_dir)

    train_set = MetaworldDataset(zarr_path=zarr_path,
                                 horizon=2,
                                 pad_before=1,
                                 pad_after=0,
                                 seed=42,
                                 val_ratio=0.0,
                                 max_train_episodes=90)
    normalizer = train_set.get_normalizer()
    train_loader = DataLoaderNonDDP(train_set,
                                    batch_size=cfg.batch_size,
                                    shuffle=True)

    lr = cfg.lrate
    soda.set_normalizer(normalizer)
    optim = get_optimizer([{'params': soda.parameters(), 'lr': lr * cfg.lrate_ratio}],
                          cfg, lr=lr)
    # With --use_amp off the scaler is disabled and scale/unscale_/step/update
    # act as plain pass-throughs.
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    if cfg.load_epoch != -1:
        target = os.path.join(model_dir, f"model_{cfg.load_epoch}.pth")
        print0("loading model at", target)
        # NOTE(review): this function only writes ``model_last.pth`` holding the
        # encoder weights alone, so resuming expects a checkpoint with full SODA
        # state plus 'EMA' and 'opt' keys, presumably from another script version.
        # torch.load unpickles the file: only load trusted checkpoints.
        checkpoint = torch.load(target, map_location=device)
        soda.load_state_dict(checkpoint['MODEL'])
        if is_main:
            ema.load_state_dict(checkpoint['EMA'])
        optim.load_state_dict(checkpoint['opt'])

    # Move again: modules attached after the first .to() (presumably the
    # normalizer installed by set_normalizer) may still live on the CPU.
    soda.to(device)

    for ep in range(cfg.load_epoch + 1, cfg.n_epoch):
        # Linear warm-up over the first warm_epoch epochs, scaled by lrate_ratio.
        optim.param_groups[0]['lr'] = lr * min((ep + 1.0) / cfg.warm_epoch, 1.0) * cfg.lrate_ratio

        soda.train()
        # Bound on every rank so the end-of-epoch code never hits a NameError.
        loss_ema = None
        if is_main:
            enc_lr = dec_lr = optim.param_groups[0]['lr']
            print(f'epoch {ep}, lr {enc_lr:f} & {dec_lr:f}')
            pbar = tqdm(train_loader)
        else:
            pbar = train_loader

        for source in pbar:
            optim.zero_grad()

            obs = source['obs']
            obs['point_cloud'] = obs['point_cloud'].to(device)
            obs['agent_pos'] = obs['agent_pos'].to(device)
            action = source["action"].to(device)

            # The forward pass is always run with use_amp=False; --use_amp only
            # toggles the GradScaler.
            loss = soda(obs, action, use_amp=False)
            scaler.scale(loss).backward()
            # Unscale first so clipping applies to the true gradient magnitudes.
            scaler.unscale_(optim)
            torch.nn.utils.clip_grad_norm_(parameters=soda.parameters(), max_norm=cfg.grad_clip_norm)
            scaler.step(optim)
            scaler.update()

            if is_main:
                ema.update()
                loss_value = loss.item()
                loss_ema = loss_value if loss_ema is None else 0.95 * loss_ema + 0.05 * loss_value
                pbar.set_description(f"loss: {loss_ema:.4f}")

        # Only rank 0 owns the output directories and the writer; saving from
        # every rank would also race on the same file.
        if is_main:
            print("epoch", ep, "loss", loss_ema)
            writer.add_scalar('lr/enc', enc_lr, ep)
            writer.add_scalar('lr/dec', dec_lr, ep)
            if loss_ema is not None:  # empty loader -> no loss this epoch
                writer.add_scalar('loss', loss_ema, ep)
            # Only the encoder weights are saved; the file is overwritten each epoch.
            checkpoint = {'MODEL': soda.encoder.state_dict()}
            torch.save(checkpoint, os.path.join(model_dir, "model_last.pth"))

    if writer is not None:
        writer.close()

if __name__ == "__main__":
    # Command-line entry point: parse arguments, seed, select the GPU and train.
    parser = argparse.ArgumentParser(description="Train SODA and save its encoder weights.")
    # Required: train() opens this path immediately, and a missing value would
    # otherwise surface as an opaque TypeError from open(None).
    parser.add_argument("--config", type=str, required=True,
                        help="path to the YAML training config")
    parser.add_argument('--local_rank', default=0, type=int,
                        help='node rank for distributed training')
    parser.add_argument("--use_amp", action='store_true', default=False,
                        help="enable GradScaler loss scaling")
    opt = parser.parse_args()
    print0(opt)

    init_seeds(no=opt.local_rank)
    # torch.distributed is never initialised (no init_process_group call), so
    # this runs as a single, non-DDP process.
    torch.cuda.set_device(opt.local_rank)
    train(opt)